import logging
import random
import os
import numpy as np
import torch
from omegaconf import OmegaConf


def set_random_seed(seed):
    """Seed every random number generator this project relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG,
    PyTorch's CPU RNG, and the RNGs of all CUDA devices.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_logger(cfg, name=None):
    """Configure logging from ``cfg.job_logging_cfg`` and return a logger.

    The OmegaConf node ``cfg.job_logging_cfg`` is converted to a plain
    container (with interpolations resolved) and handed to
    ``logging.config.dictConfig``. That call reconfigures the process-wide
    logging system.

    NOTE(review): a previous comment said "log_file_path is used when unit
    testing". Presumably ``log_file_path`` is interpolated somewhere inside
    ``job_logging_cfg``; confirm against the config files.

    Args:
        cfg: Config object exposing a ``job_logging_cfg`` node in
            ``logging.config.dictConfig`` schema.
        name: Logger name; ``None`` returns the root logger.

    Returns:
        The ``logging.Logger`` for ``name``.
    """
    # ``import logging`` alone does not load the ``logging.config`` submodule,
    # so ``logging.config.dictConfig`` could raise AttributeError unless some
    # other module had already imported it. Import it explicitly.
    from logging.config import dictConfig

    logging_cfg = OmegaConf.to_container(cfg.job_logging_cfg, resolve=True)
    dictConfig(logging_cfg)
    return logging.getLogger(name)


def add_weight_decay(net, l2_value, skip_list=()):
    """Split the trainable parameters of ``net`` into two optimizer groups.

    A parameter is exempt from weight decay when any of these holds:
    - it is one-dimensional;
    - its name ends in ``.bias``;
    - its name is listed in ``skip_list``.

    All other trainable parameters are decayed with ``l2_value``. Parameters
    with ``requires_grad`` set to False (frozen weights) are left out of both
    groups.

    Args:
        net: Module providing ``named_parameters()``.
        l2_value: Weight-decay coefficient for the decayed group.
        skip_list: Parameter names to exclude from weight decay.

    Returns:
        ``[no_decay_group, decay_group]``: two dicts with keys ``'params'`` and
        ``'weight_decay'``, in the per-parameter-options format accepted by
        ``torch.optim`` optimizers.
    """
    def _is_exempt(param_name, tensor):
        # True when this parameter should get weight_decay = 0.
        return (
            tensor.dim() == 1
            or param_name.endswith(".bias")
            or param_name in skip_list
        )

    # Frozen weights are dropped here, before classification.
    tagged = [
        (tensor, _is_exempt(param_name, tensor))
        for param_name, tensor in net.named_parameters()
        if tensor.requires_grad
    ]
    no_decay_params = [tensor for tensor, exempt in tagged if exempt]
    decay_params = [tensor for tensor, exempt in tagged if not exempt]

    return [
        {'params': no_decay_params, 'weight_decay': 0.},
        {'params': decay_params, 'weight_decay': l2_value},
    ]



def save_array_to_npy(array, filename_prefix, cfg):
    """Write ``array`` to a NumPy ``.npy`` file in the ``cfg.log.ATE`` directory.

    The file stem is
    ``{cfg.mdl_name}_{cfg.dataset.type}_{cfg.dataset.num}_{filename_prefix}``.
    Despite its name, ``filename_prefix`` ends up as the last component of the
    stem. ``np.save`` appends the ``.npy`` extension when the path lacks it.

    Args:
        array: Array-like data to save.
        filename_prefix: Final component of the file stem.
        cfg: Config providing ``mdl_name``, ``dataset.type``, ``dataset.num``
            and ``log.ATE`` (the output directory).
    """
    stem = f"{cfg.mdl_name}_{cfg.dataset.type}_{cfg.dataset.num}_{filename_prefix}"
    out_dir = cfg.log.ATE
    np.save(os.path.join(out_dir, stem), array)

    
def causal_mae(pre, tar, axis=1) -> np.ndarray:
    """Return the mean squared difference between ``pre`` and ``tar`` along ``axis``.

    NOTE(review): despite the name, this computes a mean *squared* error
    (``(pre - tar) ** 2``), not a mean absolute error. The squared form is kept
    unchanged because existing results may depend on it. Confirm which metric
    is intended before renaming the function or changing the formula.

    Args:
        pre: Predicted values (assumed to be a NumPy array; TODO confirm callers).
        tar: Target values, broadcast-compatible with ``pre``.
        axis: Axis to average over. The default of 1 yields one value per row
            of a 2-D input, matching the previous hard-coded behavior.

    Returns:
        An array of mean squared errors along ``axis``. With 2-D input and the
        default axis there is one entry per row. This is not a scalar ``float``,
        as the old annotation claimed.
    """
    squared_error = (pre - tar) ** 2
    return np.mean(squared_error, axis=axis)
